# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Callable

import pytest

from burr.core import Action, Condition, Result, State, default
from burr.core.graph import GraphBuilder, _validate_actions, _validate_transitions


# TODO -- share versions among tests, this is duplicated in test_application.py
class PassedInAction(Action):
    def __init__(
        self,
        reads: list[str],
        writes: list[str],
        fn: Callable[..., dict],
        update_fn: Callable[[dict, State], State],
        inputs: list[str],
        tags: list[str] = None,
    ):
        super(PassedInAction, self).__init__()
        self._reads = reads
        self._writes = writes
        self._fn = fn
        self._update_fn = update_fn
        self._inputs = inputs
        self._tags = tags

    def run(self, state: State, **run_kwargs) -> dict:
        return self._fn(state, **run_kwargs)

    @property
    def inputs(self) -> list[str]:
        return self._inputs

    def update(self, result: dict, state: State) -> State:
        return self._update_fn(result, state)

    @property
    def reads(self) -> list[str]:
        return self._reads

    @property
    def writes(self) -> list[str]:
        return self._writes

    @property
    def tags(self) -> list[str]:
        return self._tags or []


def test__validate_transitions_correct():
    _validate_transitions(
        [("counter", "counter", Condition.expr("count < 10")), ("counter", "result", default)],
        {"counter", "result"},
    )


def test__validate_transitions_missing_action():
    with pytest.raises(ValueError, match="not found"):
        _validate_transitions(
            [
                ("counter", "counter", Condition.expr("count < 10")),
                ("counter", "result", default),
            ],
            {"counter"},
        )


def test__validate_transitions_redundant_transition():
    with pytest.raises(ValueError, match="redundant"):
        _validate_transitions(
            [
                ("counter", "counter", Condition.expr("count < 10")),
                ("counter", "result", default),
                ("counter", "counter", default),  # this is unreachable as we already have a default
            ],
            {"counter", "result"},
        )


def test__validate_actions_valid():
    _validate_actions([Result("test")])


def test__validate_actions_empty():
    with pytest.raises(ValueError, match="at least one"):
        _validate_actions([])


def test__validate_actions_duplicated():
    with pytest.raises(ValueError, match="duplicated"):
        _validate_actions([Result("test"), Result("test")])


base_counter_action = PassedInAction(
    reads=["count"],
    writes=["count"],
    fn=lambda state: {"count": state.get("count", 0) + 1},
    update_fn=lambda result, state: state.update(**result),
    inputs=[],
    tags=["tag1", "tag2"],
)


def test_graph_builder_builds():
    graph = (
        GraphBuilder()
        .with_actions(counter=base_counter_action, result=Result("count"))
        .with_transitions(
            ("counter", "counter", Condition.expr("count < 10")), ("counter", "result")
        )
        .build()
    )
    assert len(graph.actions) == 2
    assert len(graph.transitions) == 2


def test_graph_builder_with_graph():
    graph1 = (
        GraphBuilder()
        .with_actions(counter=base_counter_action)
        .with_transitions(("counter", "counter", Condition.expr("count < 10")))
        .build()
    )
    graph2 = (
        GraphBuilder()
        .with_actions(counter2=base_counter_action)
        .with_transitions(("counter2", "counter2", Condition.expr("count < 20")))
        .build()
    )
    graph = (
        GraphBuilder()
        .with_graph(graph1)
        .with_graph(graph2)
        .with_actions(result=Result("count"))
        .with_transitions(
            ("counter", "counter2"),
            ("counter2", "result"),
        )
    )
    assert len(graph.actions) == 3
    assert len(graph.transitions) == 4


def test_graph_builder_get_next_node():
    graph = (
        GraphBuilder()
        .with_actions(counter=base_counter_action, result=Result("count"))
        .with_transitions(
            ("counter", "counter", Condition.expr("count < 10")), ("counter", "result")
        )
        .build()
    )
    assert len(graph.actions) == 2
    assert len(graph.transitions) == 2
    assert graph.get_next_node(None, State({"count": 0}), entrypoint="counter").name == "counter"


def test_get_actions_by_tag():
    action_with_tags = PassedInAction(
        reads=["count"],
        writes=["count"],
        fn=lambda state: {"count": state.get("count", 0) + 1},
        update_fn=lambda result, state: state.update(**result),
        inputs=[],
        tags=["tag1", "tag2"],
    )

    action_with_tags_2 = PassedInAction(
        reads=["count"],
        writes=["count"],
        fn=lambda state: {"count": state.get("count", 0) + 1},
        update_fn=lambda result, state: state.update(**result),
        inputs=[],
        tags=["tag1", "tag3"],
    )
    graph = (
        GraphBuilder()
        .with_actions(counter1=action_with_tags, counter2=action_with_tags_2)
        .with_transitions(("counter1", "counter2"))
        .with_transitions(("counter2", "counter1"))
        .build()
    )
    # tag1 is in both actions, tag2 is in one, tag3 is in one
    assert len(graph.get_actions_by_tag("tag1")) == 2
    assert len(graph.get_actions_by_tag("tag2")) == 1
    assert len(graph.get_actions_by_tag("tag3")) == 1
    with pytest.raises(ValueError, match="not found"):
        graph.get_actions_by_tag("tag4")
